{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "166c742c-aa92-4c4f-93a2-876f6aa395bb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# 示例：缓存及加载功能\n",
    "- 可以缓存及加载任何对象（数据、模型……）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "42841ec6-63a2-40aa-9621-fc609d67a8c1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from fastphm.data.FeatureExtractor import FeatureExtractor\n",
    "from fastphm.data.labeler.RulLabeler import RulLabeler\n",
    "from fastphm.data.loader.bearing.XJTULoader import XJTULoader\n",
    "from fastphm.data.processor.RMSProcessor import RMSProcessor\n",
    "from fastphm.data.stage.BearingStageCalculator import BearingStageCalculator\n",
    "from fastphm.data.stage.fpt.ThreeSigmaFPTCalculator import ThreeSigmaFPTCalculator\n",
    "from fastphm.util.Cache import Cache\n",
    "from fastphm.model.pytorch.PytorchModel import PytorchModel\n",
    "from fastphm.model.pytorch.basic.MLP import MLP"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1967380-798e-4df3-bd88-b0aecd898184",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 生成数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ab1339ac-fe28-4cf4-92a8-07d3c2c71c76",
   "metadata": {
    "scrolled": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG - 17:43:54 >> \n",
      "<< Root directory of dataset: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\n",
      "\tBearing1_1, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_1\n",
      "\tBearing1_2, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_2\n",
      "\tBearing1_3, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_3\n",
      "\tBearing1_4, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_4\n",
      "\tBearing1_5, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\35Hz12kN\\Bearing1_5\n",
      "\tBearing2_1, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_1\n",
      "\tBearing2_2, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_2\n",
      "\tBearing2_3, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_3\n",
      "\tBearing2_4, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_4\n",
      "\tBearing2_5, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\37.5Hz11kN\\Bearing2_5\n",
      "\tBearing3_1, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_1\n",
      "\tBearing3_2, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_2\n",
      "\tBearing3_3, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_3\n",
      "\tBearing3_4, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_4\n",
      "\tBearing3_5, location: D:\\data\\dataset\\XJTU-SY_Bearing_Datasets\\40Hz10kN\\Bearing3_5\n",
      "INFO - 17:43:54 >> Loading data entity: Bearing1_3\n",
      "INFO - 17:43:57 >> Successfully loaded data entity: Bearing1_3\n"
     ]
    }
   ],
   "source": [
    "# Root directory of the XJTU-SY bearing dataset. Set the XJTU_DATASET_ROOT environment\n",
    "# variable to point at your own copy; the original author's path is kept as the fallback.\n",
    "dataset_root = os.environ.get('XJTU_DATASET_ROOT', 'D:\\\\data\\\\dataset\\\\XJTU-SY_Bearing_Datasets')\n",
    "data_loader = XJTULoader(dataset_root)\n",
    "feature_extractor = FeatureExtractor(RMSProcessor(data_loader.continuum))\n",
    "fpt_calculator = ThreeSigmaFPTCalculator()\n",
    "stage_calculator = BearingStageCalculator(data_loader.continuum, fpt_calculator)\n",
    "\n",
    "# Get the raw data, then the feature data and the stage data for one bearing.\n",
    "# NOTE(review): the two calls after the load discard their return values, so they\n",
    "# presumably attach their results to `bearing` in place - confirm against fastphm.\n",
    "bearing = data_loader(\"Bearing1_3\", 'Horizontal Vibration')\n",
    "feature_extractor(bearing)\n",
    "stage_calculator(bearing)\n",
    "\n",
    "# Build the training set from the processed bearing with RulLabeler.\n",
    "generator = RulLabeler(2048, time_ratio=60, is_from_fpt=False, is_rectified=True)\n",
    "train_set = generator(bearing)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53c3e103-9156-47d4-966f-9952bc01183c",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 缓存数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "328c7310-aba4-4107-8fe7-0dd7bc4ae521",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG - 17:43:57 >> Start generating cache file: .\\cache\\train_set.pkl\n",
      "DEBUG - 17:43:57 >> Successfully generated cache file: .\\cache\\train_set.pkl\n"
     ]
    }
   ],
   "source": [
    "# Store the training set under the cache name 'train_set'; the recorded run logged\n",
    "# a train_set.pkl file being generated in the local cache directory.\n",
    "Cache.save(train_set, 'train_set')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b341d3a2-7ff6-4af7-8761-9c268b22a615",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 加载数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8d288e84-638b-4f11-8f8a-a965ba9a75e3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG - 17:43:57 >> Start loading cache file: .\\cache\\train_set.pkl\n",
      "DEBUG - 17:43:57 >> Successfully loaded cache file: .\\cache\\train_set.pkl\n"
     ]
    }
   ],
   "source": [
    "# Read the object cached as 'train_set' back and rebind train_set to it, so the\n",
    "# cells below work on the reloaded copy.\n",
    "# NOTE(review): the cache file is a .pkl, presumably pickle-based - only load cache\n",
    "# files you created yourself or otherwise trust.\n",
    "train_set = Cache.load('train_set')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7d7aa10c-1349-4b9c-80ec-258dbf604e41",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ -0.77383518,   0.1305342 ,   0.1555443 , ...,   0.1136541 ,\n",
       "         -0.54583549,   0.42616129],\n",
       "       [ -0.38670301,  -0.3725886 ,  -0.74648857, ...,  -0.0345588 ,\n",
       "         -0.70579052,  -0.79065561],\n",
       "       [ -0.27912861,   0.52905079,  -0.2827883 , ...,  -0.63660137,\n",
       "          0.25235411,   0.45748949],\n",
       "       ...,\n",
       "       [ -0.48875809,   0.66200487,  -2.59971619, ..., -11.16659641,\n",
       "         -3.66904736,  -6.93216324],\n",
       "       [ -2.04831362,  -0.154233  ,   3.88391018, ...,  -9.79976654,\n",
       "          1.93562508,  -7.22073317],\n",
       "       [ -3.97603512,   2.60374546,  -0.52655939, ...,  -5.99588156,\n",
       "         -4.60753441, -10.23435593]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the x attribute of the reloaded training set as this cell's output.\n",
    "train_set.x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d60e2f81-8146-4d36-aa43-fd7dc4a73c70",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 生成模型并训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "84290f34-94a4-4b34-9ad0-0d1597824b00",
   "metadata": {
    "scrolled": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - 17:43:57 >> \n",
      "<< Successfully initialized model:\n",
      "\tclass: MLP\n",
      "\tdevice: cuda\n",
      "\tdtype: torch.float32\n",
      "INFO - 17:43:58 >> \n",
      "<< Start training model:\n",
      "\tloss function: MSELoss\n",
      "\toptimizer: Adam\n",
      "\tlearning rate: 0.01\n",
      "\tweight decay: 0.01\n",
      "\tbatch size: 256\n",
      "DEBUG - 17:43:58 >> Epoch 1/10, Loss: 2.9730943799\n",
      "DEBUG - 17:43:58 >> Epoch 2/10, Loss: 2.6427245855\n",
      "DEBUG - 17:43:58 >> Epoch 3/10, Loss: 0.9938627124\n",
      "DEBUG - 17:43:58 >> Epoch 4/10, Loss: 0.4982388645\n",
      "DEBUG - 17:43:58 >> Epoch 5/10, Loss: 0.3173443541\n",
      "DEBUG - 17:43:58 >> Epoch 6/10, Loss: 0.2191797912\n",
      "DEBUG - 17:43:58 >> Epoch 7/10, Loss: 0.1704204008\n",
      "DEBUG - 17:43:58 >> Epoch 8/10, Loss: 0.1305305324\n",
      "DEBUG - 17:43:58 >> Epoch 9/10, Loss: 0.1046425104\n",
      "DEBUG - 17:43:58 >> Epoch 10/10, Loss: 0.0864575386\n",
      "INFO - 17:43:58 >> Model training completed!!!\n"
     ]
    }
   ],
   "source": [
    "# Wrap an MLP in PytorchModel. The first MLP argument (2048) equals the first argument\n",
    "# passed to RulLabeler in the data cell; 16 and 1 are presumably the hidden width and\n",
    "# the output size - TODO confirm against the MLP signature.\n",
    "model = PytorchModel(MLP(2048,16,1))\n",
    "# Train for 10 epochs with batch size 256, learning rate 0.01 and weight decay 0.01.\n",
    "# NOTE(review): this notebook sets no random seed, so the losses may differ between runs.\n",
    "model.train(train_set, epochs=10, batch_size=256, lr=0.01, weight_decay=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d746f0ee-09ad-42c0-834d-28b07ad752b5",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 缓存模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fb2b25c8-1e1c-4ea0-a8ed-84ef9f060ee6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG - 17:43:58 >> Start generating cache file: .\\cache\\model.pkl\n",
      "DEBUG - 17:43:58 >> Successfully generated cache file: .\\cache\\model.pkl\n"
     ]
    }
   ],
   "source": [
    "# Store the trained model wrapper under the cache name 'model', using the same\n",
    "# Cache.save call that was used for the training set.\n",
    "Cache.save(model, 'model')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a5f3640-3d0f-4ef7-a526-577fe95b6df3",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "### 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7f03508b-02f5-440b-9250-9bd405c58f89",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG - 17:43:58 >> Start loading cache file: .\\cache\\model.pkl\n",
      "DEBUG - 17:43:58 >> Successfully loaded cache file: .\\cache\\model.pkl\n"
     ]
    }
   ],
   "source": [
    "# Read the object cached as 'model' back and rebind model to the restored object.\n",
    "# NOTE(review): the cache file is a .pkl, presumably pickle-based - only load cache\n",
    "# files you created yourself or otherwise trust.\n",
    "model = Cache.load('model')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}